Transforms for parsing lines from a CSV file and serialized tensorflow.Examples.
Change: 123426407
This commit is contained in:
parent
6164d02144
commit
c246c37b1b
@ -125,6 +125,30 @@ py_test(
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "test_csv_parser",
|
||||
size = "small",
|
||||
srcs = ["python/learn/tests/dataframe/test_csv_parser.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":learn",
|
||||
"//tensorflow:tensorflow_py",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "test_example_parser",
|
||||
size = "small",
|
||||
srcs = ["python/learn/tests/dataframe/test_example_parser.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":learn",
|
||||
"//tensorflow:tensorflow_py",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "test_early_stopping",
|
||||
size = "medium",
|
||||
|
@ -0,0 +1,67 @@
|
||||
# Copyright 2016 Google Inc. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""A Transform that parses lines from a CSV file."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.contrib.learn.python.learn.dataframe import transform
|
||||
from tensorflow.python.ops import constant_op
|
||||
from tensorflow.python.ops import parsing_ops
|
||||
|
||||
|
||||
class CSVParser(transform.Transform):
|
||||
"""A Transform that parses lines from a CSV file."""
|
||||
|
||||
def __init__(self, column_names, default_values):
|
||||
"""Initialize `CSVParser`.
|
||||
|
||||
Args:
|
||||
column_names: a list of strings containing the names of columns to be
|
||||
output by the parser.
|
||||
default_values: a list containing each column.
|
||||
"""
|
||||
super(CSVParser, self).__init__()
|
||||
self._column_names = tuple(column_names)
|
||||
self._default_values = default_values
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
return "CSVParser"
|
||||
|
||||
@property
|
||||
def input_valency(self):
|
||||
return 1
|
||||
|
||||
@property
|
||||
def _output_names(self):
|
||||
return self.column_names
|
||||
|
||||
@transform.parameter
|
||||
def column_names(self):
|
||||
return self._column_names
|
||||
|
||||
@transform.parameter
|
||||
def default_values(self):
|
||||
return self._default_values
|
||||
|
||||
def _apply_transform(self, input_tensors):
|
||||
default_consts = [constant_op.constant(d, shape=[1])
|
||||
for d in self._default_values]
|
||||
parsed_values = parsing_ops.decode_csv(input_tensors[0],
|
||||
record_defaults=default_consts)
|
||||
# pylint: disable=not-callable
|
||||
return self.return_type(*parsed_values)
|
@ -0,0 +1,68 @@
|
||||
# Copyright 2016 Google Inc. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""A Transform that parses serialized tensorflow.Example protos."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
|
||||
from tensorflow.contrib.learn.python.learn.dataframe import transform
|
||||
from tensorflow.python.ops import parsing_ops
|
||||
|
||||
|
||||
class ExampleParser(transform.Transform):
|
||||
"""A Transform that parses serialized `tensorflow.Example` protos."""
|
||||
|
||||
def __init__(self, features):
|
||||
"""Initialize `ExampleParser`.
|
||||
|
||||
The `features` argument must be an object that can be converted to an
|
||||
`OrderedDict`. The keys should be strings and will be used to name the
|
||||
output. Values should be either `VarLenFeature` or `FixedLenFeature`. If
|
||||
`features` is a dict, it will be sorted by key.
|
||||
Args:
|
||||
features: An object that can be converted to an `OrderedDict` mapping
|
||||
column names to feature definitions.
|
||||
"""
|
||||
super(ExampleParser, self).__init__()
|
||||
if isinstance(features, dict):
|
||||
self._ordered_features = collections.OrderedDict(sorted(features.items(
|
||||
), key=lambda f: f[0]))
|
||||
else:
|
||||
self._ordered_features = collections.OrderedDict(features)
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
return "ExampleParser"
|
||||
|
||||
@property
|
||||
def input_valency(self):
|
||||
return 1
|
||||
|
||||
@property
|
||||
def _output_names(self):
|
||||
return list(self._ordered_features.keys())
|
||||
|
||||
@transform.parameter
|
||||
def feature_definitions(self):
|
||||
return self._ordered_features
|
||||
|
||||
def _apply_transform(self, input_tensors):
|
||||
parsed_values = parsing_ops.parse_example(input_tensors[0],
|
||||
features=self._ordered_features)
|
||||
# pylint: disable=not-callable
|
||||
return self.return_type(**parsed_values)
|
@ -0,0 +1,50 @@
|
||||
# Copyright 2016 Google Inc. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Tests for learn.python.learn.dataframe.transforms.csv_parser."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
from tensorflow.contrib.learn.python.learn.dataframe.transforms import csv_parser
|
||||
from tensorflow.contrib.learn.python.learn.tests.dataframe import mocks
|
||||
|
||||
|
||||
class CSVParserTestCase(tf.test.TestCase):
|
||||
|
||||
def testParse(self):
|
||||
parser = csv_parser.CSVParser(column_names=["col0", "col1", "col2"],
|
||||
default_values=["", "", 1.4])
|
||||
csv_lines = ["one,two,2.5", "four,five,6.0"]
|
||||
csv_input = tf.constant(csv_lines, dtype=tf.string, shape=[len(csv_lines)])
|
||||
csv_column = mocks.MockColumn("csv", csv_input)
|
||||
expected_output = [np.array([b"one", b"four"]),
|
||||
np.array([b"two", b"five"]),
|
||||
np.array([2.5, 6.0])]
|
||||
output_columns = parser(csv_column)
|
||||
self.assertEqual(3, len(output_columns))
|
||||
cache = {}
|
||||
output_tensors = [o.build(cache) for o in output_columns]
|
||||
self.assertEqual(3, len(output_tensors))
|
||||
with self.test_session() as sess:
|
||||
output = sess.run(output_tensors)
|
||||
for expected, actual in zip(expected_output, output):
|
||||
np.testing.assert_array_equal(actual, expected)
|
||||
|
||||
if __name__ == "__main__":
|
||||
tf.test.main()
|
@ -0,0 +1,134 @@
|
||||
# Copyright 2016 Google Inc. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Tests for learn.dataframe.transforms.example_parser."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
from google.protobuf import text_format
|
||||
|
||||
from tensorflow.contrib.learn.python.learn.dataframe.transforms import example_parser
|
||||
from tensorflow.contrib.learn.python.learn.tests.dataframe import mocks
|
||||
from tensorflow.core.example import example_pb2
|
||||
|
||||
|
||||
class ExampleParserTestCase(tf.test.TestCase):
|
||||
"""Test class for `ExampleParser`."""
|
||||
|
||||
def setUp(self):
|
||||
super(ExampleParserTestCase, self).setUp()
|
||||
self.example1 = example_pb2.Example()
|
||||
text_format.Parse("features: { "
|
||||
" feature: { "
|
||||
" key: 'int_feature' "
|
||||
" value: { "
|
||||
" int64_list: { "
|
||||
" value: [ 21, 2, 5 ] "
|
||||
" } "
|
||||
" } "
|
||||
" } "
|
||||
" feature: { "
|
||||
" key: 'string_feature' "
|
||||
" value: { "
|
||||
" bytes_list: { "
|
||||
" value: [ 'armadillo' ] "
|
||||
" } "
|
||||
" } "
|
||||
" } "
|
||||
"} ", self.example1)
|
||||
self.example2 = example_pb2.Example()
|
||||
text_format.Parse("features: { "
|
||||
" feature: { "
|
||||
" key: 'int_feature' "
|
||||
" value: { "
|
||||
" int64_list: { "
|
||||
" value: [ 4, 5, 6 ] "
|
||||
" } "
|
||||
" } "
|
||||
" } "
|
||||
" feature: { "
|
||||
" key: 'string_feature' "
|
||||
" value: { "
|
||||
" bytes_list: { "
|
||||
" value: [ 'car', 'train' ] "
|
||||
" } "
|
||||
" } "
|
||||
" } "
|
||||
"} ", self.example2)
|
||||
self.example_column = mocks.MockColumn(
|
||||
"example",
|
||||
tf.constant(
|
||||
[self.example1.SerializeToString(),
|
||||
self.example2.SerializeToString()],
|
||||
dtype=tf.string,
|
||||
shape=[2]))
|
||||
self.features = (("string_feature", tf.VarLenFeature(dtype=tf.string)),
|
||||
("int_feature",
|
||||
tf.FixedLenFeature(shape=[3],
|
||||
dtype=tf.int64,
|
||||
default_value=[0, 0, 0])))
|
||||
|
||||
self.expected_string_values = np.array(
|
||||
list(self.example1.features.feature["string_feature"].bytes_list.value) +
|
||||
list(self.example2.features.feature["string_feature"].bytes_list.value))
|
||||
self.expected_string_indices = np.array([[0, 0], [1, 0], [1, 1]])
|
||||
self.expected_int_feature = np.array([list(self.example1.features.feature[
|
||||
"int_feature"].int64_list.value), list(self.example2.features.feature[
|
||||
"int_feature"].int64_list.value)])
|
||||
|
||||
def testParseWithTupleDefinition(self):
|
||||
parser = example_parser.ExampleParser(self.features)
|
||||
output_columns = parser(self.example_column)
|
||||
self.assertEqual(2, len(output_columns))
|
||||
cache = {}
|
||||
output_tensors = [o.build(cache) for o in output_columns]
|
||||
self.assertEqual(2, len(output_tensors))
|
||||
|
||||
with self.test_session() as sess:
|
||||
string_feature, int_feature = sess.run(output_tensors)
|
||||
np.testing.assert_array_equal(string_feature.shape, np.array([2, 2]))
|
||||
np.testing.assert_array_equal(int_feature.shape, np.array([2, 3]))
|
||||
np.testing.assert_array_equal(self.expected_string_values,
|
||||
string_feature.values)
|
||||
np.testing.assert_array_equal(self.expected_string_indices,
|
||||
string_feature.indices)
|
||||
np.testing.assert_array_equal(self.expected_int_feature,
|
||||
int_feature)
|
||||
|
||||
def testParseWithDictDefinition(self):
|
||||
parser = example_parser.ExampleParser(dict(self.features))
|
||||
output_columns = parser(self.example_column)
|
||||
self.assertEqual(2, len(output_columns))
|
||||
cache = {}
|
||||
output_tensors = [o.build(cache) for o in output_columns]
|
||||
self.assertEqual(2, len(output_tensors))
|
||||
|
||||
with self.test_session() as sess:
|
||||
int_feature, string_feature = sess.run(output_tensors)
|
||||
np.testing.assert_array_equal(string_feature.shape, np.array([2, 2]))
|
||||
np.testing.assert_array_equal(int_feature.shape, np.array([2, 3]))
|
||||
np.testing.assert_array_equal(self.expected_string_values,
|
||||
string_feature.values)
|
||||
np.testing.assert_array_equal(self.expected_string_indices,
|
||||
string_feature.indices)
|
||||
np.testing.assert_array_equal(self.expected_int_feature,
|
||||
int_feature)
|
||||
|
||||
if __name__ == "__main__":
|
||||
tf.test.main()
|
Loading…
Reference in New Issue
Block a user