Fixed predictions in Linear/Dnn/Combined estimators.
Aligned usage of metrics with the metrics API. Added custom metrics and prediction tests. Change: 123178457
This commit is contained in:
parent
19d33a1a3e
commit
6430c245b8
@ -160,7 +160,7 @@ py_test(
|
||||
|
||||
py_test(
|
||||
name = "dnn_linear_combined_test",
|
||||
size = "small",
|
||||
size = "medium",
|
||||
srcs = ["python/learn/estimators/dnn_linear_combined_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
|
@ -17,8 +17,10 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import inspect
|
||||
import math
|
||||
|
||||
import numpy as np
|
||||
import six
|
||||
|
||||
from tensorflow.contrib import layers
|
||||
@ -103,6 +105,37 @@ class _DNNLinearCombinedBaseEstimator(estimator.BaseEstimator):
|
||||
self._dnn_weight_collection = "DNNLinearCombined_dnn"
|
||||
self._linear_weight_collection = "DNNLinearCombined_linear"
|
||||
|
||||
def predict(self, x=None, input_fn=None, batch_size=None):
  """Runs inference and returns class ids or regression values.

  Args:
    x: features.
    input_fn: Input function. If set, x must be None.
    batch_size: Override default batch size.

  Returns:
    Numpy array of predicted classes or regression values.
  """
  outputs = self._infer_model(x=x, input_fn=input_fn, batch_size=batch_size)
  if self._n_classes <= 1:
    # Not a classifier: hand the inferred values back untouched.
    return outputs
  # Classification: pick the index of the largest value in every row.
  return np.argmax(outputs, axis=1)
|
||||
|
||||
def predict_proba(self, x=None, input_fn=None, batch_size=None):
  """Returns prediction probabilities for given features (classification).

  Args:
    x: features.
    input_fn: Input function. If set, x must be None.
    batch_size: Override default batch size.

  Returns:
    Numpy array of predicted probabilities.
  """
  # The inference result is returned as-is, without an argmax reduction.
  return self._infer_model(x=x, input_fn=input_fn, batch_size=batch_size)
|
||||
|
||||
def _get_train_ops(self, features, targets):
|
||||
"""See base class."""
|
||||
global_step = variables.get_global_step()
|
||||
@ -126,45 +159,55 @@ class _DNNLinearCombinedBaseEstimator(estimator.BaseEstimator):
|
||||
with ops.get_default_graph().colocate_with(global_step):
|
||||
return state_ops.assign_add(global_step, 1).op, loss
|
||||
|
||||
def _get_eval_ops(self, features, targets, metrics):
|
||||
"""See base class."""
|
||||
predictions = self._get_predict_ops(features)
|
||||
def _run_metrics(self, predictions, targets, metrics, weights):
  """Applies every metric function to the predictions and targets.

  Args:
    predictions: Predictions `Tensor`.
    targets: Targets `Tensor`; cast to the dtype of `predictions` before use.
    metrics: `dict` mapping a result name to a metric function that is called
      as `metric(predictions, targets)` or, when its signature has a `weights`
      argument, as `metric(predictions, targets, weights=weights)`.
    weights: Value forwarded as `weights` to the metrics that accept it.

  Returns:
    `dict` mapping every name in `metrics` to what its function returned.
  """
  result = {}
  targets = math_ops.cast(targets, predictions.dtype)
  # inspect.getargspec no longer exists on Python 3.11+; getfullargspec is the
  # Python 3 replacement and is missing on Python 2, so fall back as needed.
  getargspec = getattr(inspect, "getfullargspec", None) or inspect.getargspec
  for name, metric in six.iteritems(metrics):
    # Index 0 of the argspec is the list of positional argument names.
    if "weights" in getargspec(metric)[0]:
      result[name] = metric(predictions, targets, weights=weights)
    else:
      result[name] = metric(predictions, targets)

  return result
|
||||
|
||||
def _get_default_metric_functions(self):
|
||||
def _get_eval_ops(self, features, targets, metrics=None):
  """See base class.

  Args:
    features: Input features, passed to `self._logits`.
    targets: Target values used by the loss and by every metric.
    metrics: Optional `dict` mapping a result name to a metric function. When
      it is None and the model has more than one class, a single "accuracy"
      metric (`metrics_lib.streaming_accuracy`) is used instead.

  Returns:
    `dict` of evaluation results. It always holds "loss", holds "eval_auc"
    when the model has exactly two classes, and holds one entry for each
    metric that was run.
  """
  logits = self._logits(features)
  result = {"loss": metrics_lib.streaming_mean(self._loss(
      logits, targets,
      weight_tensor=self._get_weight_tensor(features)))}

  # Adding default metrics
  if metrics is None and self._n_classes > 1:
    metrics = {"accuracy": metrics_lib.streaming_accuracy}

  # AUC is computed on the sigmoid of the logits, for two-class models only.
  if self._n_classes == 2:
    predictions = math_ops.sigmoid(logits)
    result["eval_auc"] = metrics_lib.streaming_auc(predictions, targets)

  if metrics:
    # Metrics receive the non-probability predictions (proba=False).
    predictions = self._logits_to_predictions(logits, proba=False)
    result.update(self._run_metrics(predictions, targets, metrics,
                                    self._get_weight_tensor(features)))

  return result
|
||||
|
||||
def _get_predict_ops(self, features):
  """See base class.

  Args:
    features: Input features, passed to `self._logits`.

  Returns:
    The result of `self._logits_to_predictions(logits, proba=True)`.
  """
  logits = self._logits(features)
  # proba=True requests the probability form rather than class indices.
  return self._logits_to_predictions(logits, proba=True)
|
||||
|
||||
def _logits_to_predictions(self, logits, proba=False):
  """Converts logits into predictions.

  Args:
    logits: Logits `Tensor`.
    proba: If True return `nn.softmax` of the logits, otherwise the argmax
      along dimension 1.

  Returns:
    With fewer than two classes, the logits reshaped to `[-1]`. Otherwise the
    softmax or argmax described above.
  """
  n_classes = self._n_classes
  if n_classes < 2:
    return array_ops.reshape(logits, [-1])

  if n_classes == 2:
    # Prepend zeros along dimension 1 so the logits gain a second entry.
    zero_logits = array_ops.zeros_like(logits)
    logits = array_ops.concat(1, [zero_logits, logits])

  if not proba:
    return math_ops.argmax(logits, 1)
  return nn.softmax(logits)
|
||||
|
||||
def _get_feature_ops_from_example(self, examples_batch):
|
||||
column_types = layers.create_dict_for_parse_example(
|
||||
|
@ -19,6 +19,7 @@ from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
from tensorflow.contrib.learn.python.learn.estimators import _sklearn
|
||||
|
||||
|
||||
def _get_quantile_based_buckets(feature_values, num_buckets):
|
||||
@ -229,6 +230,58 @@ class DNNLinearCombinedClassifierTest(tf.test.TestCase):
|
||||
scores = classifier.evaluate(input_fn=_iris_input_fn, steps=100)
|
||||
self.assertGreater(scores['accuracy'], 0.9)
|
||||
|
||||
def testPredict(self):
  """Tests predict and predict_proba after training on a tiny binary set."""

  def _input_fn_train():
    # Four identical feature rows: one is labelled 1, three are labelled 0.
    target = tf.constant([[1], [0], [0], [0]])
    features = {'x': tf.ones(shape=[4, 1], dtype=tf.float32),}
    return features, target

  def _input_fn_predict():
    # Prediction input returns the features only, without targets.
    features = {'x': tf.ones(shape=[4, 1], dtype=tf.float32),}
    return features

  classifier = tf.contrib.learn.DNNLinearCombinedClassifier(
      linear_feature_columns=[tf.contrib.layers.real_valued_column('x')],
      dnn_feature_columns=[tf.contrib.layers.real_valued_column('x')],
      dnn_hidden_units=[3, 3])

  classifier.train(input_fn=_input_fn_train, steps=100)
  # One of the four training rows is positive, so each row is expected to get
  # probabilities close to [0.75, 0.25] and predicted class 0.
  probs = classifier.predict_proba(input_fn=_input_fn_predict)
  self.assertAllClose([[0.75, 0.25]] * 4, probs, 0.01)
  classes = classifier.predict(input_fn=_input_fn_predict)
  self.assertListEqual([0] * 4, list(classes))
|
||||
|
||||
def testCustomMetrics(self):
  """Tests that custom metrics passed to evaluate are reported."""

  def _input_fn_train():
    # Four identical feature rows: one is labelled 1, three are labelled 0.
    target = tf.constant([[1], [0], [0], [0]])
    features = {'x': tf.ones(shape=[4, 1], dtype=tf.float32),}
    return features, target

  classifier = tf.contrib.learn.DNNLinearCombinedClassifier(
      linear_feature_columns=[tf.contrib.layers.real_valued_column('x')],
      dnn_feature_columns=[tf.contrib.layers.real_valued_column('x')],
      dnn_hidden_units=[3, 3])

  classifier.train(input_fn=_input_fn_train, steps=100)
  scores = classifier.evaluate(
      input_fn=_input_fn_train,
      steps=100,
      metrics={
          'my_accuracy': tf.contrib.metrics.streaming_accuracy,
          'my_precision': tf.contrib.metrics.streaming_precision
      })
  # The custom metric names must be reported alongside the loss.
  self.assertTrue(set(['loss', 'my_accuracy', 'my_precision']).issubset(set(
      scores.keys())))
  # The reported accuracy must agree with accuracy recomputed from predict().
  predictions = classifier.predict(input_fn=_input_fn_train)
  self.assertEqual(_sklearn.accuracy_score([1, 0, 0, 0], predictions),
                   scores['my_accuracy'])
|
||||
|
||||
|
||||
class DNNLinearCombinedRegressorTest(tf.test.TestCase):
|
||||
|
||||
|
@ -259,8 +259,7 @@ class BaseEstimator(sklearn.BaseEstimator):
|
||||
Returns:
|
||||
Numpy array of predicted classes or regression values.
|
||||
"""
|
||||
return self._infer_model(x=x, input_fn=input_fn,
|
||||
batch_size=batch_size)
|
||||
return self._infer_model(x=x, input_fn=input_fn, batch_size=batch_size)
|
||||
|
||||
@property
|
||||
def model_dir(self):
|
||||
@ -296,6 +295,8 @@ class BaseEstimator(sklearn.BaseEstimator):
|
||||
def _get_eval_ops(self, features, targets, metrics):
|
||||
"""Method that builds model graph and returns evaluation ops.
|
||||
|
||||
Expected to be overridden by sub-classes that require custom support.
|
||||
|
||||
Args:
|
||||
features: `Tensor` or `dict` of `Tensor` objects.
|
||||
targets: `Tensor` or `dict` of `Tensor` objects.
|
||||
@ -304,11 +305,7 @@ class BaseEstimator(sklearn.BaseEstimator):
|
||||
Returns:
|
||||
metrics: `dict` of `Tensor` objects.
|
||||
"""
|
||||
predictions = self._get_predict_ops(features)
|
||||
result = {}
|
||||
for name, metric in six.iteritems(metrics):
|
||||
result[name] = metric(predictions, targets)
|
||||
return result
|
||||
raise NotImplementedError('_get_eval_ops not implemented in BaseEstimator')
|
||||
|
||||
def _get_feature_ops_from_example(self, examples_batch):
|
||||
"""Method that returns features given the batch of examples.
|
||||
@ -324,16 +321,6 @@ class BaseEstimator(sklearn.BaseEstimator):
|
||||
raise NotImplementedError('_get_feature_ops_from_example not implemented '
|
||||
'in BaseEstimator')
|
||||
|
||||
def _get_default_metric_functions(self):
  """Method that provides default metric operations.

  This function is intended to be overridden by sub-classes.

  Returns:
    `dict` of functions that take predictions and targets `Tensor` objects and
    return `Tensor`. This implementation returns an empty `dict`.
  """
  return {}
|
||||
|
||||
def _check_inputs(self, features, targets):
|
||||
if self._features_info is not None:
|
||||
if not tensor_signature.tensors_compatible(features, self._features_info):
|
||||
@ -460,9 +447,7 @@ class BaseEstimator(sklearn.BaseEstimator):
|
||||
global_step = contrib_framework.create_global_step(g)
|
||||
features, targets = input_fn()
|
||||
self._check_inputs(features, targets)
|
||||
eval_dict = self._get_eval_ops(features, targets,
|
||||
metrics if metrics is not None else
|
||||
self._get_default_metric_functions())
|
||||
eval_dict = self._get_eval_ops(features, targets, metrics)
|
||||
update_op, eval_dict = self._extract_metric_update_ops(eval_dict)
|
||||
eval_results, _ = evaluate(graph=g,
|
||||
output_dir=eval_dir,
|
||||
@ -475,9 +460,13 @@ class BaseEstimator(sklearn.BaseEstimator):
|
||||
max_steps=steps)
|
||||
return eval_results
|
||||
|
||||
def _infer_model(self,
|
||||
x=None, input_fn=None, feed_fn=None,
|
||||
batch_size=None):
|
||||
def _get_features_from_input_fn(self, input_fn):
  """Calls `input_fn` and returns only the features it produced.

  Args:
    input_fn: Callable taking no arguments.

  Returns:
    The first element when `input_fn` returns a list or tuple, otherwise the
    returned value itself.
  """
  output = input_fn()
  # A list/tuple result carries the features in its first position.
  return output[0] if isinstance(output, (list, tuple)) else output
|
||||
|
||||
def _infer_model(self, x=None, input_fn=None, feed_fn=None, batch_size=None):
|
||||
# Converts inputs into tf.DataFrame / tf.Series.
|
||||
batch_size = -1 if batch_size is None else batch_size
|
||||
if x is not None:
|
||||
@ -487,7 +476,7 @@ class BaseEstimator(sklearn.BaseEstimator):
|
||||
with ops.Graph().as_default() as g:
|
||||
random_seed.set_random_seed(self._config.tf_random_seed)
|
||||
contrib_framework.create_global_step(g)
|
||||
features, _ = input_fn()
|
||||
features = self._get_features_from_input_fn(input_fn)
|
||||
predictions = self._get_predict_ops(features)
|
||||
return_dict = True
|
||||
if not isinstance(predictions, dict):
|
||||
@ -570,7 +559,8 @@ class Estimator(BaseEstimator):
|
||||
Returns:
|
||||
Numpy array of predicted classes or regression values.
|
||||
"""
|
||||
predictions = self._infer_model(x=x, input_fn=input_fn,
|
||||
predictions = self._infer_model(x=x,
|
||||
input_fn=input_fn,
|
||||
batch_size=batch_size)
|
||||
if self._classification:
|
||||
for key in predictions:
|
||||
@ -589,8 +579,7 @@ class Estimator(BaseEstimator):
|
||||
Returns:
|
||||
Numpy array of predicted probabilities.
|
||||
"""
|
||||
return self._infer_model(x=x, input_fn=input_fn,
|
||||
batch_size=batch_size)
|
||||
return self._infer_model(x=x, input_fn=input_fn, batch_size=batch_size)
|
||||
|
||||
def _get_train_ops(self, features, targets):
|
||||
"""Method that builds model graph and returns trainer ops.
|
||||
@ -645,6 +634,9 @@ class Estimator(BaseEstimator):
|
||||
"""
|
||||
predictions, loss = self._model_fn(features, targets, ModeKeys.EVAL)
|
||||
result = {'loss': loss}
|
||||
if metrics is None:
|
||||
metrics = _EVAL_METRICS[
|
||||
'classification' if self._classification else 'regression']
|
||||
if isinstance(targets, dict) and len(targets) == 1:
|
||||
# Unpack single target into just tensor.
|
||||
targets = targets[targets.keys()[0]]
|
||||
@ -671,15 +663,6 @@ class Estimator(BaseEstimator):
|
||||
predictions, _ = self._model_fn(features, targets, ModeKeys.INFER)
|
||||
return predictions
|
||||
|
||||
def _get_default_metric_functions(self):
  """Looks up the default metrics for this estimator's problem type.

  Returns:
    The `_EVAL_METRICS` entry for 'classification' when
    `self._classification` is truthy, otherwise the 'regression' entry.
  """
  problem_type = 'classification' if self._classification else 'regression'
  return _EVAL_METRICS[problem_type]
|
||||
|
||||
def _get_feature_ops_from_example(self, examples_batch):
|
||||
"""Unimplemented.
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user