Add examples supporting docs in get_started.
PiperOrigin-RevId: 163994110
This commit is contained in:
parent
be759452c6
commit
5951ab51a9
@ -372,6 +372,7 @@ filegroup(
|
||||
"//tensorflow/core/util/tensor_bundle:all_files",
|
||||
"//tensorflow/examples/android:all_files",
|
||||
"//tensorflow/examples/benchmark:all_files",
|
||||
"//tensorflow/examples/get_started/regression:all_files",
|
||||
"//tensorflow/examples/how_tos/reading_data:all_files",
|
||||
"//tensorflow/examples/image_retraining:all_files",
|
||||
"//tensorflow/examples/label_image:all_files",
|
||||
|
19
tensorflow/examples/get_started/__init__.py
Normal file
19
tensorflow/examples/get_started/__init__.py
Normal file
@ -0,0 +1,19 @@
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""A collection of "getting started" examples."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
37
tensorflow/examples/get_started/regression/BUILD
Normal file
37
tensorflow/examples/get_started/regression/BUILD
Normal file
@ -0,0 +1,37 @@
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
exports_files(["LICENSE"])
|
||||
|
||||
filegroup(
|
||||
name = "all_files",
|
||||
srcs = glob(
|
||||
["**/*"],
|
||||
exclude = [
|
||||
"**/METADATA",
|
||||
"**/OWNERS",
|
||||
],
|
||||
),
|
||||
visibility = ["//tensorflow:__subpackages__"],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "test",
|
||||
size = "medium",
|
||||
srcs = [
|
||||
"dnn_regression.py",
|
||||
"imports85.py",
|
||||
"linear_regression.py",
|
||||
"linear_regression_categorical.py",
|
||||
"test.py",
|
||||
],
|
||||
srcs_version = "PY2AND3",
|
||||
tags = [
|
||||
"manual",
|
||||
"notap",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow:tensorflow_py",
|
||||
"//third_party/py/numpy",
|
||||
"//third_party/py/pandas",
|
||||
],
|
||||
)
|
20
tensorflow/examples/get_started/regression/__init__.py
Normal file
20
tensorflow/examples/get_started/regression/__init__.py
Normal file
@ -0,0 +1,20 @@
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
"""A collection of regression examples using `Estimators`."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
91
tensorflow/examples/get_started/regression/dnn_regression.py
Normal file
91
tensorflow/examples/get_started/regression/dnn_regression.py
Normal file
@ -0,0 +1,91 @@
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Regression using the DNNRegressor Estimator."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
import imports85 # pylint: disable=g-bad-import-order
|
||||
|
||||
STEPS = 5000
|
||||
|
||||
|
||||
def main(argv):
|
||||
"""Builds, trains, and evaluates the model."""
|
||||
assert len(argv) == 1
|
||||
(x_train, y_train), (x_test, y_test) = imports85.load_data()
|
||||
|
||||
# Build the training input_fn.
|
||||
input_train = tf.estimator.inputs.pandas_input_fn(
|
||||
x=x_train, y=y_train, num_epochs=None, shuffle=True)
|
||||
|
||||
# Build the validation input_fn.
|
||||
input_test = tf.estimator.inputs.pandas_input_fn(
|
||||
x=x_test, y=y_test, shuffle=True)
|
||||
|
||||
# The first way assigns a unique weight to each category. To do this you must
|
||||
# specify the category's vocabulary (values outside this specification will
|
||||
# receive a weight of zero). Here we specify the vocabulary using a list of
|
||||
# options. The vocabulary can also be specified with a vocabulary file (using
|
||||
# `categorical_column_with_vocabulary_file`). For features covering a
|
||||
# range of positive integers use `categorical_column_with_identity`.
|
||||
body_style_vocab = ["hardtop", "wagon", "sedan", "hatchback", "convertible"]
|
||||
body_style = tf.feature_column.categorical_column_with_vocabulary_list(
|
||||
key="body-style", vocabulary_list=body_style_vocab)
|
||||
make = tf.feature_column.categorical_column_with_hash_bucket(
|
||||
key="make", hash_bucket_size=50)
|
||||
|
||||
feature_columns = [
|
||||
tf.feature_column.numeric_column(key="curb-weight"),
|
||||
tf.feature_column.numeric_column(key="highway-mpg"),
|
||||
# Since this is a DNN model, convert categorical columns from sparse
|
||||
# to dense.
|
||||
# Wrap them in an `indicator_column` to create a
|
||||
# one-hot vector from the input.
|
||||
tf.feature_column.indicator_column(body_style),
|
||||
# Or use an `embedding_column` to create a trainable vector for each
|
||||
# index.
|
||||
tf.feature_column.embedding_column(make, dimension=3),
|
||||
]
|
||||
|
||||
# Build a DNNRegressor, with 2x20-unit hidden layers, with the feature columns
|
||||
# defined above as input.
|
||||
model = tf.estimator.DNNRegressor(
|
||||
hidden_units=[20, 20], feature_columns=feature_columns)
|
||||
|
||||
# Train the model.
|
||||
model.train(input_fn=input_train, steps=STEPS)
|
||||
|
||||
# Evaluate how the model performs on data it has not yet seen.
|
||||
eval_result = model.evaluate(input_fn=input_test)
|
||||
|
||||
# The evaluation returns a Python dictionary. The "average_loss" key holds the
|
||||
# Mean Squared Error (MSE).
|
||||
average_loss = eval_result["average_loss"]
|
||||
|
||||
# Convert MSE to Root Mean Square Error (RMSE).
|
||||
print("\n" + 80 * "*")
|
||||
print("\nRMS error for the test set: ${:.0f}".format(average_loss**0.5))
|
||||
|
||||
print()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# The Estimator periodically generates "INFO" logs; make these logs visible.
|
||||
tf.logging.set_verbosity(tf.logging.INFO)
|
||||
tf.app.run(main=main)
|
107
tensorflow/examples/get_started/regression/imports85.py
Normal file
107
tensorflow/examples/get_started/regression/imports85.py
Normal file
@ -0,0 +1,107 @@
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""A dataset loader for imports85.data."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import tensorflow as tf
|
||||
|
||||
header = collections.OrderedDict([
|
||||
("symboling", np.int32),
|
||||
("normalized-losses", np.float32),
|
||||
("make", str),
|
||||
("fuel-type", str),
|
||||
("aspiration", str),
|
||||
("num-of-doors", str),
|
||||
("body-style", str),
|
||||
("drive-wheels", str),
|
||||
("engine-location", str),
|
||||
("wheel-base", np.float32),
|
||||
("length", np.float32),
|
||||
("width", np.float32),
|
||||
("height", np.float32),
|
||||
("curb-weight", np.float32),
|
||||
("engine-type", str),
|
||||
("num-of-cylinders", str),
|
||||
("engine-size", np.float32),
|
||||
("fuel-system", str),
|
||||
("bore", np.float32),
|
||||
("stroke", np.float32),
|
||||
("compression-ratio", np.float32),
|
||||
("horsepower", np.float32),
|
||||
("peak-rpm", np.float32),
|
||||
("city-mpg", np.float32),
|
||||
("highway-mpg", np.float32),
|
||||
("price", np.float32)
|
||||
]) # pyformat: disable
|
||||
|
||||
|
||||
def raw():
|
||||
"""Get the imports85 data and load it as a pd.DataFrame."""
|
||||
url = "https://archive.ics.uci.edu/ml/machine-learning-databases/autos/imports-85.data" # pylint: disable=line-too-long
|
||||
# Download and cache the data.
|
||||
path = tf.contrib.keras.utils.get_file(url.split("/")[-1], url)
|
||||
|
||||
# Load the CSV data into a pandas dataframe.
|
||||
df = pd.read_csv(path, names=header.keys(), dtype=header, na_values="?")
|
||||
|
||||
return df
|
||||
|
||||
|
||||
def load_data(y_name="price", train_fraction=0.7, seed=None):
|
||||
"""Returns the imports85 shuffled and split into train and test subsets.
|
||||
|
||||
A description of the data is available at:
|
||||
https://archive.ics.uci.edu/ml/datasets/automobile
|
||||
|
||||
The data itself can be found at:
|
||||
https://archive.ics.uci.edu/ml/machine-learning-databases/autos/imports-85.data
|
||||
|
||||
Args:
|
||||
y_name: the column to return as the label.
|
||||
train_fraction: the fraction of the dataset to use for training.
|
||||
seed: The random seed to use when shuffling the data. `None` generates a
|
||||
unique shuffle every run.
|
||||
Returns:
|
||||
a pair of pairs where the first pair is the training data, and the second
|
||||
is the test data:
|
||||
`(x_train, y_train), (x_test, y_test) = get_imports85_dataset(...)`
|
||||
`x` contains a pandas DataFrame of features, while `y` contains the label
|
||||
array.
|
||||
"""
|
||||
# Load the raw data columns.
|
||||
data = raw()
|
||||
|
||||
# Delete rows with unknowns
|
||||
data = data.dropna()
|
||||
|
||||
# Shuffle the data
|
||||
np.random.seed(seed)
|
||||
|
||||
# Split the data into train/test subsets.
|
||||
x_train = data.sample(frac=train_fraction, random_state=seed)
|
||||
x_test = data.drop(x_train.index)
|
||||
|
||||
# Extract the label from the features dataframe.
|
||||
y_train = x_train.pop(y_name)
|
||||
y_test = x_test.pop(y_name)
|
||||
|
||||
return (x_train, y_train), (x_test, y_test)
|
@ -0,0 +1,96 @@
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Linear regression using the LinearRegressor Estimator."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
import imports85 # pylint: disable=g-bad-import-order
|
||||
|
||||
STEPS = 1000
|
||||
|
||||
|
||||
def main(argv):
|
||||
"""Builds, trains, and evaluates the model."""
|
||||
assert len(argv) == 1
|
||||
(x_train, y_train), (x_test, y_test) = imports85.load_data()
|
||||
|
||||
# Build the training input_fn.
|
||||
input_train = tf.estimator.inputs.pandas_input_fn(
|
||||
x=x_train,
|
||||
y=y_train,
|
||||
# Setting `num_epochs` to `None` lets the `inpuf_fn` generate data
|
||||
# indefinitely, leaving the call to `Estimator.train` in control.
|
||||
num_epochs=None,
|
||||
shuffle=True)
|
||||
|
||||
# Build the validation input_fn.
|
||||
input_test = tf.estimator.inputs.pandas_input_fn(
|
||||
x=x_test, y=y_test, shuffle=True)
|
||||
|
||||
feature_columns = [
|
||||
# "curb-weight" and "highway-mpg" are numeric columns.
|
||||
tf.feature_column.numeric_column(key="curb-weight"),
|
||||
tf.feature_column.numeric_column(key="highway-mpg"),
|
||||
]
|
||||
|
||||
# Build the Estimator.
|
||||
model = tf.estimator.LinearRegressor(feature_columns=feature_columns)
|
||||
|
||||
# Train the model.
|
||||
# By default, the Estimators log output every 100 steps.
|
||||
model.train(input_fn=input_train, steps=STEPS)
|
||||
|
||||
# Evaluate how the model performs on data it has not yet seen.
|
||||
eval_result = model.evaluate(input_fn=input_test)
|
||||
|
||||
# The evaluation returns a Python dictionary. The "average_loss" key holds the
|
||||
# Mean Squared Error (MSE).
|
||||
average_loss = eval_result["average_loss"]
|
||||
|
||||
# Convert MSE to Root Mean Square Error (RMSE).
|
||||
print("\n" + 80 * "*")
|
||||
print("\nRMS error for the test set: ${:.0f}".format(average_loss**0.5))
|
||||
|
||||
# Run the model in prediction mode.
|
||||
input_dict = {
|
||||
"curb-weight": np.array([2000, 3000]),
|
||||
"highway-mpg": np.array([30, 40])
|
||||
}
|
||||
predict_input_fn = tf.estimator.inputs.numpy_input_fn(
|
||||
input_dict, shuffle=False)
|
||||
predict_results = model.predict(input_fn=predict_input_fn)
|
||||
|
||||
# Print the prediction results.
|
||||
print("\nPrediction results:")
|
||||
for i, prediction in enumerate(predict_results):
|
||||
msg = ("Curb weight: {: 4d}lbs, "
|
||||
"Highway: {: 0d}mpg, "
|
||||
"Prediction: ${: 9.2f}")
|
||||
msg = msg.format(input_dict["curb-weight"][i], input_dict["highway-mpg"][i],
|
||||
prediction["predictions"][0])
|
||||
|
||||
print(" " + msg)
|
||||
print()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# The Estimator periodically generates "INFO" logs; make these logs visible.
|
||||
tf.logging.set_verbosity(tf.logging.INFO)
|
||||
tf.app.run(main=main)
|
@ -0,0 +1,101 @@
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Linear regression with categorical features."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
import imports85 # pylint: disable=g-bad-import-order
|
||||
|
||||
STEPS = 1000
|
||||
|
||||
|
||||
def main(argv):
|
||||
"""Builds, trains, and evaluates the model."""
|
||||
assert len(argv) == 1
|
||||
(x_train, y_train), (x_test, y_test) = imports85.load_data()
|
||||
|
||||
# Build the training input_fn.
|
||||
input_train = tf.estimator.inputs.pandas_input_fn(
|
||||
x=x_train,
|
||||
y=y_train,
|
||||
# Setting `num_epochs` to `None` lets the `inpuf_fn` generate data
|
||||
# indefinitely, leaving the call to `Estimator.train` in control.
|
||||
num_epochs=None,
|
||||
shuffle=True)
|
||||
|
||||
# Build the validation input_fn.
|
||||
input_test = tf.estimator.inputs.pandas_input_fn(
|
||||
x=x_test, y=y_test, shuffle=True)
|
||||
|
||||
# The following code demonstrates two of the ways that `feature_columns` can
|
||||
# be used to build a model with categorical inputs.
|
||||
|
||||
# The first way assigns a unique weight to each category. To do this, you must
|
||||
# specify the category's vocabulary (values outside this specification will
|
||||
# receive a weight of zero).
|
||||
# Alternatively, you can define the vocabulary in a file (by calling
|
||||
# `categorical_column_with_vocabulary_file`) or as a range of positive
|
||||
# integers (by calling `categorical_column_with_identity`)
|
||||
body_style_vocab = ["hardtop", "wagon", "sedan", "hatchback", "convertible"]
|
||||
body_style_column = tf.feature_column.categorical_column_with_vocabulary_list(
|
||||
key="body-style", vocabulary_list=body_style_vocab)
|
||||
|
||||
# The second way, appropriate for an unspecified vocabulary, is to create a
|
||||
# hashed column. It will create a fixed length list of weights, and
|
||||
# automatically assign each input categort to a weight. Due to the
|
||||
# pseudo-randomness of the process, some weights may be shared between
|
||||
# categories, while others will remain unused.
|
||||
make_column = tf.feature_column.categorical_column_with_hash_bucket(
|
||||
key="make", hash_bucket_size=50)
|
||||
|
||||
feature_columns = [
|
||||
# This model uses the same two numeric features as `linear_regressor.py`
|
||||
tf.feature_column.numeric_column(key="curb-weight"),
|
||||
tf.feature_column.numeric_column(key="highway-mpg"),
|
||||
# This model adds two categorical colums that will adjust the price based
|
||||
# on "make" and "body-style".
|
||||
body_style_column,
|
||||
make_column,
|
||||
]
|
||||
|
||||
# Build the Estimator.
|
||||
model = tf.estimator.LinearRegressor(feature_columns=feature_columns)
|
||||
|
||||
# Train the model.
|
||||
# By default, the Estimators log output every 100 steps.
|
||||
model.train(input_fn=input_train, steps=STEPS)
|
||||
|
||||
# Evaluate how the model performs on data it has not yet seen.
|
||||
eval_result = model.evaluate(input_fn=input_test)
|
||||
|
||||
# The evaluation returns a Python dictionary. The "average_loss" key holds the
|
||||
# Mean Squared Error (MSE).
|
||||
average_loss = eval_result["average_loss"]
|
||||
|
||||
# Convert MSE to Root Mean Square Error (RMSE).
|
||||
print("\n" + 80 * "*")
|
||||
print("\nRMS error for the test set: ${:.0f}".format(average_loss**0.5))
|
||||
|
||||
print()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# The Estimator periodically generates "INFO" logs; make these logs visible.
|
||||
tf.logging.set_verbosity(tf.logging.INFO)
|
||||
tf.app.run(main=main)
|
73
tensorflow/examples/get_started/regression/test.py
Normal file
73
tensorflow/examples/get_started/regression/test.py
Normal file
@ -0,0 +1,73 @@
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""A simple smoke test that runs these examples for 1 training iteraton."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import sys
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from six.moves import StringIO
|
||||
|
||||
import tensorflow.examples.get_started.regression.imports85 as imports85
|
||||
|
||||
import tensorflow.examples.get_started.regression.dnn_regression as dnn_regression # pylint: disable=g-bad-import-order,g-import-not-at-top
|
||||
import tensorflow.examples.get_started.regression.linear_regression as linear_regression
|
||||
import tensorflow.examples.get_started.regression.linear_regression_categorical as linear_regression_categorical
|
||||
|
||||
from tensorflow.python.platform import googletest
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
def four_lines():
|
||||
# pylint: disable=line-too-long
|
||||
text = StringIO("""
|
||||
1,?,alfa-romero,gas,std,two,hatchback,rwd,front,94.50,171.20,65.50,52.40,2823,ohcv,six,152,mpfi,2.68,3.47,9.00,154,5000,19,26,16500
|
||||
2,164,audi,gas,std,four,sedan,fwd,front,99.80,176.60,66.20,54.30,2337,ohc,four,109,mpfi,3.19,3.40,10.00,102,5500,24,30,13950
|
||||
2,164,audi,gas,std,four,sedan,4wd,front,99.40,176.60,66.40,54.30,2824,ohc,five,136,mpfi,3.19,3.40,8.00,115,5500,18,22,17450
|
||||
2,?,audi,gas,std,two,sedan,fwd,front,99.80,177.30,66.30,53.10,2507,ohc,five,136,mpfi,3.19,3.40,8.50,110,5500,19,25,15250""")
|
||||
# pylint: enable=line-too-long
|
||||
|
||||
return pd.read_csv(text, names=imports85.header.keys(),
|
||||
dtype=imports85.header, na_values='?')
|
||||
|
||||
|
||||
class RegressionTest(googletest.TestCase):
|
||||
"""Test the regression examples in this directory."""
|
||||
|
||||
@test.mock.patch.dict(imports85.__dict__, {'raw': four_lines})
|
||||
@test.mock.patch.dict(linear_regression.__dict__, {'STEPS': 1})
|
||||
@test.mock.patch.dict(sys.modules, {'imports85': imports85})
|
||||
def test_linear_regression(self):
|
||||
linear_regression.main([])
|
||||
|
||||
@test.mock.patch.dict(imports85.__dict__, {'raw': four_lines})
|
||||
@test.mock.patch.dict(linear_regression_categorical.__dict__, {'STEPS': 1})
|
||||
@test.mock.patch.dict(sys.modules, {'imports85': imports85})
|
||||
def test_linear_regression_categorical(self):
|
||||
linear_regression_categorical.main([])
|
||||
|
||||
@test.mock.patch.dict(imports85.__dict__, {'raw': four_lines})
|
||||
@test.mock.patch.dict(dnn_regression.__dict__, {'STEPS': 1})
|
||||
@test.mock.patch.dict(sys.modules, {'imports85': imports85})
|
||||
def test_dnn_regression(self):
|
||||
dnn_regression.main([])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
googletest.main()
|
Loading…
Reference in New Issue
Block a user