Add csv dataset example to get_started/regression.

PiperOrigin-RevId: 167754634
Mark Daoust 2017-09-06 12:14:53 -07:00 committed by TensorFlower Gardener
parent 0f6a17c51e
commit acc7c00588
6 changed files with 215 additions and 87 deletions

View File

@@ -146,6 +146,9 @@ for i in range(100):
   assert i == value
 ```

+Note: Currently, one-shot iterators are the only type that is easily usable
+with an `Estimator`.
+
 An **initializable** iterator requires you to run an explicit
 `iterator.initializer` operation before using it. In exchange for this
 inconvenience, it enables you to *parameterize* the definition of the dataset,

@@ -452,6 +455,9 @@ dataset = dataset.flat_map(
         .filter(lambda line: tf.not_equal(tf.substr(line, 0, 1), "#"))))
 ```

+For a full example of parsing a CSV file using datasets, see [`imports85.py`](https://www.tensorflow.org/code/tensorflow/examples/get_started/regression/imports85.py)
+in @{$get_started/linear_regression}.
+
 <!--
 TODO(mrry): Add these sections.
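Taken together, the two notes added above suggest one pattern: parse CSV lines with `tf.decode_csv` and hand the result to an `Estimator` through a one-shot iterator. A minimal sketch (an illustration, not code from this commit; the file name `data.csv` and its two-column layout are placeholders):

```python
import tensorflow as tf

def _parse(line):
  # record_defaults gives both the parsed dtypes and fallback values.
  x, y = tf.decode_csv(line, record_defaults=[[0.0], [0.0]])
  return {"x": x}, y

def input_fn():
  dataset = tf.contrib.data.TextLineDataset("data.csv").map(_parse)
  # One-shot iterators need no explicit initialization, which is what
  # makes them the easy choice inside an Estimator's input_fn.
  return dataset.batch(32).repeat().make_one_shot_iterator().get_next()
```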

View File

@@ -28,15 +28,21 @@ STEPS = 5000
 def main(argv):
   """Builds, trains, and evaluates the model."""
   assert len(argv) == 1
-  (x_train, y_train), (x_test, y_test) = imports85.load_data()
+  (train, test) = imports85.dataset()

   # Build the training input_fn.
-  input_train = tf.estimator.inputs.pandas_input_fn(
-      x=x_train, y=y_train, num_epochs=None, shuffle=True)
+  def input_train():
+    return (
+        # Shuffling with a buffer larger than the data set ensures
+        # that the examples are well mixed.
+        train.shuffle(1000).batch(128)
+        # Repeat forever
+        .repeat().make_one_shot_iterator().get_next())

   # Build the validation input_fn.
-  input_test = tf.estimator.inputs.pandas_input_fn(
-      x=x_test, y=y_test, shuffle=True)
+  def input_test():
+    return (test.shuffle(1000).batch(128)
+            .make_one_shot_iterator().get_next())

   # The first way assigns a unique weight to each category. To do this you must
   # specify the category's vocabulary (values outside this specification will
@@ -71,7 +77,7 @@ def main(argv):
   # Train the model.
   model.train(input_fn=input_train, steps=STEPS)

-  # Evaluate how the model performs on data it has not yet seen.
+  # Evaluate how the model performs on data it has not yet seen.
   eval_result = model.evaluate(input_fn=input_test)

   # The evaluation returns a Python dictionary. The "average_loss" key holds the
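The shuffle/batch/repeat chain used by the new `input_fn`s can be observed in isolation. A small sketch on toy data, assuming the same TF 1.x `tf.contrib.data` API the examples use:

```python
import tensorflow as tf

dataset = tf.contrib.data.Dataset.range(10)

# shuffle(buffer_size) fills a buffer and samples uniformly from it; a
# buffer at least as large as the data set gives a full shuffle, as the
# comments above note. repeat() with no argument cycles indefinitely, so
# the `steps` argument to Estimator.train decides when training stops.
batch = (dataset.shuffle(10).batch(4)
         .repeat().make_one_shot_iterator().get_next())

with tf.Session() as sess:
  for _ in range(3):
    print(sess.run(batch))  # e.g. [7 0 3 9], then [2 5 8 1], ...
```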

View File

@@ -21,53 +21,149 @@ from __future__ import print_function
 import collections

 import numpy as np
-import pandas as pd
 import tensorflow as tf

-header = collections.OrderedDict([
-    ("symboling", np.int32),
-    ("normalized-losses", np.float32),
-    ("make", str),
-    ("fuel-type", str),
-    ("aspiration", str),
-    ("num-of-doors", str),
-    ("body-style", str),
-    ("drive-wheels", str),
-    ("engine-location", str),
-    ("wheel-base", np.float32),
-    ("length", np.float32),
-    ("width", np.float32),
-    ("height", np.float32),
-    ("curb-weight", np.float32),
-    ("engine-type", str),
-    ("num-of-cylinders", str),
-    ("engine-size", np.float32),
-    ("fuel-system", str),
-    ("bore", np.float32),
-    ("stroke", np.float32),
-    ("compression-ratio", np.float32),
-    ("horsepower", np.float32),
-    ("peak-rpm", np.float32),
-    ("city-mpg", np.float32),
-    ("highway-mpg", np.float32),
-    ("price", np.float32)
+try:
+  import pandas as pd  # pylint: disable=g-import-not-at-top
+except ImportError:
+  pass
+
+URL = "https://archive.ics.uci.edu/ml/machine-learning-databases/autos/imports-85.data"
+
+# Order is important for the csv-readers, so we use an OrderedDict here.
+defaults = collections.OrderedDict([
+    ("symboling", [0]),
+    ("normalized-losses", [0.0]),
+    ("make", [""]),
+    ("fuel-type", [""]),
+    ("aspiration", [""]),
+    ("num-of-doors", [""]),
+    ("body-style", [""]),
+    ("drive-wheels", [""]),
+    ("engine-location", [""]),
+    ("wheel-base", [0.0]),
+    ("length", [0.0]),
+    ("width", [0.0]),
+    ("height", [0.0]),
+    ("curb-weight", [0.0]),
+    ("engine-type", [""]),
+    ("num-of-cylinders", [""]),
+    ("engine-size", [0.0]),
+    ("fuel-system", [""]),
+    ("bore", [0.0]),
+    ("stroke", [0.0]),
+    ("compression-ratio", [0.0]),
+    ("horsepower", [0.0]),
+    ("peak-rpm", [0.0]),
+    ("city-mpg", [0.0]),
+    ("highway-mpg", [0.0]),
+    ("price", [0.0])
 ])  # pyformat: disable

-def raw():
-  """Get the imports85 data and load it as a pd.DataFrame."""
-  url = "https://archive.ics.uci.edu/ml/machine-learning-databases/autos/imports-85.data"  # pylint: disable=line-too-long
-
-  # Download and cache the data.
-  path = tf.contrib.keras.utils.get_file(url.split("/")[-1], url)
+types = collections.OrderedDict((key, type(value[0]))
+                                for key, value in defaults.items())

-  # Load the CSV data into a pandas dataframe.
-  df = pd.read_csv(path, names=header.keys(), dtype=header, na_values="?")
+
+def _get_imports85():
+  path = tf.contrib.keras.utils.get_file(URL.split("/")[-1], URL)
+  return path
+
+
+def dataset(y_name="price", train_fraction=0.7):
+  """Load the imports85 data as a (train,test) pair of `Dataset`.
+
+  Each dataset generates (features_dict, label) pairs.
+
+  Args:
+    y_name: The name of the column to use as the label.
+    train_fraction: A float, the fraction of data to use for training. The
+        remainder will be used for evaluation.
+  Returns:
+    A (train,test) pair of `Datasets`
+  """
+  # Download and cache the data
+  path = _get_imports85()
+
+  # Define how the lines of the file should be parsed
+  def decode_line(line):
+    """Convert a csv line into a (features_dict,label) pair."""
+    # Decode the line to a tuple of items based on the types of
+    # csv_header.values().
+    items = tf.decode_csv(line, defaults.values())
+
+    # Convert the keys and items to a dict.
+    pairs = zip(defaults.keys(), items)
+    features_dict = dict(pairs)
+
+    # Remove the label from the features_dict
+    label = features_dict.pop(y_name)
+
+    return features_dict, label
+
+  def has_no_question_marks(line):
+    """Returns True if the line of text has no question marks."""
+    # split the line into an array of characters
+    chars = tf.string_split(line[tf.newaxis], "").values
+    # for each character check if it is a question mark
+    is_question = tf.equal(chars, "?")
+    any_question = tf.reduce_any(is_question)
+    no_question = ~any_question
+
+    return no_question
+
+  def in_training_set(line):
+    """Returns a boolean tensor, true if the line is in the training set."""
+    # If you randomly split the dataset you won't get the same split in both
+    # sessions if you stop and restart training later. Also a simple
+    # random split won't work with a dataset that's too big to `.cache()` as
+    # we are doing here.
+    num_buckets = 1000000
+    bucket_id = tf.string_to_hash_bucket_fast(line, num_buckets)
+    # Use the hash bucket id as a random number that's deterministic per example
+    return bucket_id < int(train_fraction * num_buckets)
+
+  def in_test_set(line):
+    """Returns a boolean tensor, true if the line is in the test set."""
+    # Items not in the training set are in the test set.
+    # This line must use `~` instead of `not` because `not` only works on
+    # python booleans but we are dealing with symbolic tensors.
+    return ~in_training_set(line)
+
+  base_dataset = (tf.contrib.data
+                  # Get the lines from the file.
+                  .TextLineDataset(path)
+                  # drop lines with question marks.
+                  .filter(has_no_question_marks))
+
+  train = (base_dataset
+           # Take only the training-set lines.
+           .filter(in_training_set)
+           # Cache data so you only read the file once.
+           .cache()
+           # Decode each line into a (features_dict, label) pair.
+           .map(decode_line))
+
+  # Do the same for the test-set.
+  test = (base_dataset.filter(in_test_set).cache().map(decode_line))
+
+  return train, test
+
+
+def raw_dataframe():
+  """Load the imports85 data as a pd.DataFrame."""
+  # Download and cache the data
+  path = _get_imports85()
+
+  # Load it into a pandas dataframe
+  df = pd.read_csv(path, names=types.keys(), dtype=types, na_values="?")

   return df


 def load_data(y_name="price", train_fraction=0.7, seed=None):
-  """Returns the imports85 shuffled and split into train and test subsets.
+  """Get the imports85 data set.

   A description of the data is available at:
     https://archive.ics.uci.edu/ml/datasets/automobile
@@ -88,7 +184,7 @@ def load_data(y_name="price", train_fraction=0.7, seed=None):
       array.
   """
   # Load the raw data columns.
-  data = raw()
+  data = raw_dataframe()

   # Delete rows with unknowns
   data = data.dropna()
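The deterministic hash-bucket split that `in_training_set` implements is easy to check on its own. A sketch with made-up lines (not part of the commit):

```python
import tensorflow as tf

lines = tf.constant(["line-a", "line-b", "line-c", "line-d"])
num_buckets = 1000000
train_fraction = 0.7

# The hash is a pure function of each line's contents, so the train/test
# assignment is reproducible across sessions and restarts, and no state
# needs to fit in memory, unlike a random shuffle-and-split.
bucket_id = tf.string_to_hash_bucket_fast(lines, num_buckets)
in_train = bucket_id < int(train_fraction * num_buckets)

with tf.Session() as sess:
  print(sess.run(in_train))  # The same boolean mask on every run.
```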

View File

@@ -29,20 +29,21 @@ STEPS = 1000
 def main(argv):
   """Builds, trains, and evaluates the model."""
   assert len(argv) == 1
-  (x_train, y_train), (x_test, y_test) = imports85.load_data()
+  (train, test) = imports85.dataset()

   # Build the training input_fn.
-  input_train = tf.estimator.inputs.pandas_input_fn(
-      x=x_train,
-      y=y_train,
-      # Setting `num_epochs` to `None` lets the `inpuf_fn` generate data
-      # indefinitely, leaving the call to `Estimator.train` in control.
-      num_epochs=None,
-      shuffle=True)
+  def input_train():
+    return (
+        # Shuffling with a buffer larger than the data set ensures
+        # that the examples are well mixed.
+        train.shuffle(1000).batch(128)
+        # Repeat forever
+        .repeat().make_one_shot_iterator().get_next())

   # Build the validation input_fn.
-  input_test = tf.estimator.inputs.pandas_input_fn(
-      x=x_test, y=y_test, shuffle=True)
+  def input_test():
+    return (test.shuffle(1000).batch(128)
+            .make_one_shot_iterator().get_next())

   feature_columns = [
       # "curb-weight" and "highway-mpg" are numeric columns.

View File

@@ -28,20 +28,21 @@ STEPS = 1000
 def main(argv):
   """Builds, trains, and evaluates the model."""
   assert len(argv) == 1
-  (x_train, y_train), (x_test, y_test) = imports85.load_data()
+  (train, test) = imports85.dataset()

   # Build the training input_fn.
-  input_train = tf.estimator.inputs.pandas_input_fn(
-      x=x_train,
-      y=y_train,
-      # Setting `num_epochs` to `None` lets the `inpuf_fn` generate data
-      # indefinitely, leaving the call to `Estimator.train` in control.
-      num_epochs=None,
-      shuffle=True)
+  def input_train():
+    return (
+        # Shuffling with a buffer larger than the data set ensures
+        # that the examples are well mixed.
+        train.shuffle(1000).batch(128)
+        # Repeat forever
+        .repeat().make_one_shot_iterator().get_next())

   # Build the validation input_fn.
-  input_test = tf.estimator.inputs.pandas_input_fn(
-      x=x_test, y=y_test, shuffle=True)
+  def input_test():
+    return (test.shuffle(1000).batch(128)
+            .make_one_shot_iterator().get_next())

   # The following code demonstrates two of the ways that `feature_columns` can
   # be used to build a model with categorical inputs.
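For the "two ways" the comment above refers to, the standard TF 1.x options are an explicit vocabulary versus hash buckets. A sketch (illustrative, not the commit's exact code; the vocabulary values are assumptions based on the imports85 `body-style` column):

```python
import tensorflow as tf

# Way 1: an explicit vocabulary; each listed value gets its own weight,
# and values outside the list fall back to a default.
body_style = tf.feature_column.categorical_column_with_vocabulary_list(
    key="body-style",
    vocabulary_list=["hardtop", "wagon", "sedan", "hatchback", "convertible"])

# Way 2: hash each string into a fixed number of buckets, so no vocabulary
# needs to be known up front (distinct values may collide).
make = tf.feature_column.categorical_column_with_hash_bucket(
    key="make", hash_bucket_size=50)

# Linear models accept categorical columns directly.
model = tf.estimator.LinearRegressor(feature_columns=[body_style, make])
```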

View File

@@ -26,48 +26,66 @@ from six.moves import StringIO
 import tensorflow.examples.get_started.regression.imports85 as imports85
-import tensorflow.examples.get_started.regression.dnn_regression as dnn_regression  # pylint: disable=g-bad-import-order,g-import-not-at-top

 sys.modules["imports85"] = imports85

+# pylint: disable=g-bad-import-order,g-import-not-at-top
+import tensorflow.contrib.data as data
+import tensorflow.examples.get_started.regression.dnn_regression as dnn_regression
 import tensorflow.examples.get_started.regression.linear_regression as linear_regression
 import tensorflow.examples.get_started.regression.linear_regression_categorical as linear_regression_categorical
 from tensorflow.python.platform import googletest
 from tensorflow.python.platform import test
 # pylint: disable=g-bad-import-order,g-import-not-at-top

-def four_lines():
-  # pylint: disable=line-too-long
-  text = StringIO("""
-1,?,alfa-romero,gas,std,two,hatchback,rwd,front,94.50,171.20,65.50,52.40,2823,ohcv,six,152,mpfi,2.68,3.47,9.00,154,5000,19,26,16500
-2,164,audi,gas,std,four,sedan,fwd,front,99.80,176.60,66.20,54.30,2337,ohc,four,109,mpfi,3.19,3.40,10.00,102,5500,24,30,13950
-2,164,audi,gas,std,four,sedan,4wd,front,99.40,176.60,66.40,54.30,2824,ohc,five,136,mpfi,3.19,3.40,8.00,115,5500,18,22,17450
-2,?,audi,gas,std,two,sedan,fwd,front,99.80,177.30,66.30,53.10,2507,ohc,five,136,mpfi,3.19,3.40,8.50,110,5500,19,25,15250""")
-  # pylint: enable=line-too-long
+# pylint: disable=line-too-long
+FOUR_LINES = "\n".join([
+    "1,?,alfa-romero,gas,std,two,hatchback,rwd,front,94.50,171.20,65.50,52.40,2823,ohcv,six,152,mpfi,2.68,3.47,9.00,154,5000,19,26,16500",
+    "2,164,audi,gas,std,four,sedan,fwd,front,99.80,176.60,66.20,54.30,2337,ohc,four,109,mpfi,3.19,3.40,10.00,102,5500,24,30,13950",
+    "2,164,audi,gas,std,four,sedan,4wd,front,99.40,176.60,66.40,54.30,2824,ohc,five,136,mpfi,3.19,3.40,8.00,115,5500,18,22,17450",
+    "2,?,audi,gas,std,two,sedan,fwd,front,99.80,177.30,66.30,53.10,2507,ohc,five,136,mpfi,3.19,3.40,8.50,110,5500,19,25,15250",])
-  return pd.read_csv(text, names=imports85.header.keys(),
-                     dtype=imports85.header, na_values='?')
+# pylint: enable=line-too-long
+
+
+def four_lines_dataframe():
+  text = StringIO(FOUR_LINES)
+  return pd.read_csv(text, names=imports85.types.keys(),
+                     dtype=imports85.types, na_values="?")
+
+
+def four_lines_dataset(*args, **kwargs):
+  del args, kwargs
+  return data.Dataset.from_tensor_slices(FOUR_LINES.split("\n"))
+

 class RegressionTest(googletest.TestCase):
   """Test the regression examples in this directory."""

-  @test.mock.patch.dict(imports85.__dict__, {'raw': four_lines})
-  @test.mock.patch.dict(linear_regression.__dict__, {'STEPS': 1})
-  @test.mock.patch.dict(sys.modules, {'imports85': imports85})
+  @test.mock.patch.dict(data.__dict__,
+                        {"TextLineDataset": four_lines_dataset})
+  @test.mock.patch.dict(imports85.__dict__, {"_get_imports85": (lambda: None)})
+  @test.mock.patch.dict(linear_regression.__dict__, {"STEPS": 1})
   def test_linear_regression(self):
-    linear_regression.main([])
+    linear_regression.main([""])

-  @test.mock.patch.dict(imports85.__dict__, {'raw': four_lines})
-  @test.mock.patch.dict(linear_regression_categorical.__dict__, {'STEPS': 1})
-  @test.mock.patch.dict(sys.modules, {'imports85': imports85})
+  @test.mock.patch.dict(data.__dict__,
+                        {"TextLineDataset": four_lines_dataset})
+  @test.mock.patch.dict(imports85.__dict__, {"_get_imports85": (lambda: None)})
+  @test.mock.patch.dict(linear_regression_categorical.__dict__, {"STEPS": 1})
   def test_linear_regression_categorical(self):
-    linear_regression_categorical.main([])
+    linear_regression_categorical.main([""])

-  @test.mock.patch.dict(imports85.__dict__, {'raw': four_lines})
-  @test.mock.patch.dict(dnn_regression.__dict__, {'STEPS': 1})
-  @test.mock.patch.dict(sys.modules, {'imports85': imports85})
+  @test.mock.patch.dict(data.__dict__,
+                        {"TextLineDataset": four_lines_dataset})
+  @test.mock.patch.dict(imports85.__dict__, {"_get_imports85": (lambda: None)})
+  @test.mock.patch.dict(dnn_regression.__dict__, {"STEPS": 1})
   def test_dnn_regression(self):
-    dnn_regression.main([])
+    dnn_regression.main([""])


-if __name__ == '__main__':
+if __name__ == "__main__":
   googletest.main()
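The tests above avoid the network by patching two names: `_get_imports85` so no file is downloaded, and `TextLineDataset` so any path argument resolves to the four in-memory CSV lines. A sketch of that patching trick in isolation (not part of the commit; the fake data is made up):

```python
import tensorflow as tf
import tensorflow.contrib.data as data
from tensorflow.python.platform import test

def fake_text_line_dataset(*args, **kwargs):
  del args, kwargs  # The path argument is irrelevant for in-memory data.
  return data.Dataset.from_tensor_slices(["a,1", "b,2"])

# patch.dict restores data.__dict__ on exit, so the substitution
# never leaks beyond the block (or, in the tests, the test method).
with test.mock.patch.dict(data.__dict__,
                          {"TextLineDataset": fake_text_line_dataset}):
  dataset = data.TextLineDataset("/does/not/exist.csv")  # gets the fake
  line = dataset.make_one_shot_iterator().get_next()
  with tf.Session() as sess:
    print(sess.run(line))  # b"a,1"
```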