Adds integration tests for DNNClassifier.

PiperOrigin-RevId: 157592010
This commit is contained in:
A. Unique TensorFlower 2017-05-31 08:57:42 -07:00 committed by TensorFlower Gardener
parent a05de6cd24
commit be5d98a8bd

View File

@ -695,6 +695,169 @@ class DNNRegressorIntegrationTest(test.TestCase):
batch_size=batch_size)
class DNNClassifierIntegrationTest(test.TestCase):
def setUp(self):
self._model_dir = tempfile.mkdtemp()
def tearDown(self):
if self._model_dir:
shutil.rmtree(self._model_dir)
def _test_complete_flow(
self, train_input_fn, eval_input_fn, predict_input_fn, input_dimension,
n_classes, batch_size):
feature_columns = [
feature_column.numeric_column('x', shape=(input_dimension,))]
est = dnn.DNNClassifier(
hidden_units=(2, 2),
feature_columns=feature_columns,
n_classes=n_classes,
model_dir=self._model_dir)
# TRAIN
num_steps = 10
est.train(train_input_fn, steps=num_steps)
# EVALUTE
scores = est.evaluate(eval_input_fn)
self.assertEqual(num_steps, scores[ops.GraphKeys.GLOBAL_STEP])
self.assertIn('loss', six.iterkeys(scores))
# PREDICT
predicted_proba = np.array([
x[prediction_keys.PredictionKeys.PROBABILITIES]
for x in est.predict(predict_input_fn)
])
self.assertAllEqual((batch_size, n_classes), predicted_proba.shape)
# EXPORT
feature_spec = feature_column.make_parse_example_spec(feature_columns)
serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn(
feature_spec)
export_dir = est.export_savedmodel(tempfile.mkdtemp(),
serving_input_receiver_fn)
self.assertTrue(gfile.Exists(export_dir))
def test_numpy_input_fn(self):
"""Tests complete flow with numpy_input_fn."""
n_classes = 2
input_dimension = 2
batch_size = 10
data = np.linspace(0., 2., batch_size * input_dimension, dtype=np.float32)
x_data = data.reshape(batch_size, input_dimension)
y_data = np.reshape(data[:batch_size], (batch_size, 1))
# learn y = x
train_input_fn = numpy_io.numpy_input_fn(
x={'x': x_data},
y=y_data,
batch_size=batch_size,
num_epochs=None,
shuffle=True)
eval_input_fn = numpy_io.numpy_input_fn(
x={'x': x_data},
y=y_data,
batch_size=batch_size,
shuffle=False)
predict_input_fn = numpy_io.numpy_input_fn(
x={'x': x_data},
batch_size=batch_size,
shuffle=False)
self._test_complete_flow(
train_input_fn=train_input_fn,
eval_input_fn=eval_input_fn,
predict_input_fn=predict_input_fn,
input_dimension=input_dimension,
n_classes=n_classes,
batch_size=batch_size)
def test_pandas_input_fn(self):
"""Tests complete flow with pandas_input_fn."""
if not HAS_PANDAS:
return
input_dimension = 1
n_classes = 2
batch_size = 10
data = np.linspace(0., 2., batch_size, dtype=np.float32)
x = pd.DataFrame({'x': data})
y = pd.Series(data)
train_input_fn = pandas_io.pandas_input_fn(
x=x,
y=y,
batch_size=batch_size,
num_epochs=None,
shuffle=True)
eval_input_fn = pandas_io.pandas_input_fn(
x=x,
y=y,
batch_size=batch_size,
shuffle=False)
predict_input_fn = pandas_io.pandas_input_fn(
x=x,
batch_size=batch_size,
shuffle=False)
self._test_complete_flow(
train_input_fn=train_input_fn,
eval_input_fn=eval_input_fn,
predict_input_fn=predict_input_fn,
input_dimension=input_dimension,
n_classes=n_classes,
batch_size=batch_size)
def test_input_fn_from_parse_example(self):
"""Tests complete flow with input_fn constructed from parse_example."""
input_dimension = 2
n_classes = 2
batch_size = 10
data = np.linspace(0., 2., batch_size * input_dimension, dtype=np.float32)
data = data.reshape(batch_size, input_dimension)
serialized_examples = []
for datum in data:
example = example_pb2.Example(features=feature_pb2.Features(
feature={
'x': feature_pb2.Feature(
float_list=feature_pb2.FloatList(value=datum)),
'y': feature_pb2.Feature(
float_list=feature_pb2.FloatList(value=datum[:1])),
}))
serialized_examples.append(example.SerializeToString())
feature_spec = {
'x': parsing_ops.FixedLenFeature([input_dimension], dtypes.float32),
'y': parsing_ops.FixedLenFeature([1], dtypes.float32),
}
def _train_input_fn():
feature_map = parsing_ops.parse_example(serialized_examples, feature_spec)
features = _queue_parsed_features(feature_map)
labels = features.pop('y')
return features, labels
def _eval_input_fn():
feature_map = parsing_ops.parse_example(
input_lib.limit_epochs(serialized_examples, num_epochs=1),
feature_spec)
features = _queue_parsed_features(feature_map)
labels = features.pop('y')
return features, labels
def _predict_input_fn():
feature_map = parsing_ops.parse_example(
input_lib.limit_epochs(serialized_examples, num_epochs=1),
feature_spec)
features = _queue_parsed_features(feature_map)
features.pop('y')
return features, None
self._test_complete_flow(
train_input_fn=_train_input_fn,
eval_input_fn=_eval_input_fn,
predict_input_fn=_predict_input_fn,
input_dimension=input_dimension,
n_classes=n_classes,
batch_size=batch_size)
def _full_var_name(var_name):
return '%s/part_0:0' % var_name