From 6430c245b8157ae3d63c399ac639bb51ae1e83da Mon Sep 17 00:00:00 2001 From: Mustafa Ispir Date: Tue, 24 May 2016 20:35:13 -0800 Subject: [PATCH] Fixed predictions in Linear/Dnn/Combined estimators. Aligned usage of metrics with the metrics API. Added custom metrics and prediction tests. Change: 123178457 --- tensorflow/contrib/learn/BUILD | 2 +- .../learn/estimators/dnn_linear_combined.py | 99 +++++++++++++------ .../estimators/dnn_linear_combined_test.py | 53 ++++++++++ .../python/learn/estimators/estimator.py | 55 ++++------- 4 files changed, 144 insertions(+), 65 deletions(-) diff --git a/tensorflow/contrib/learn/BUILD b/tensorflow/contrib/learn/BUILD index 3b21db1bab2..052d2d81b35 100644 --- a/tensorflow/contrib/learn/BUILD +++ b/tensorflow/contrib/learn/BUILD @@ -160,7 +160,7 @@ py_test( py_test( name = "dnn_linear_combined_test", - size = "small", + size = "medium", srcs = ["python/learn/estimators/dnn_linear_combined_test.py"], srcs_version = "PY2AND3", deps = [ diff --git a/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py b/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py index ccb5ffae647..b1f34fa7400 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py +++ b/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py @@ -17,8 +17,10 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import inspect import math +import numpy as np import six from tensorflow.contrib import layers @@ -103,6 +105,37 @@ class _DNNLinearCombinedBaseEstimator(estimator.BaseEstimator): self._dnn_weight_collection = "DNNLinearCombined_dnn" self._linear_weight_collection = "DNNLinearCombined_linear" + def predict(self, x=None, input_fn=None, batch_size=None): + """Returns predictions for given features. + + Args: + x: features. + input_fn: Input function. If set, x must be None. + batch_size: Override default batch size. 
+ + Returns: + Numpy array of predicted classes or regression values. + """ + predictions = self._infer_model(x=x, + input_fn=input_fn, + batch_size=batch_size) + if self._n_classes > 1: + predictions = np.argmax(predictions, axis=1) + return predictions + + def predict_proba(self, x=None, input_fn=None, batch_size=None): + """Returns prediction probabilities for given features (classification). + + Args: + x: features. + input_fn: Input function. If set, x and y must be None. + batch_size: Override default batch size. + + Returns: + Numpy array of predicted probabilities. + """ + return self._infer_model(x=x, input_fn=input_fn, batch_size=batch_size) + def _get_train_ops(self, features, targets): """See base class.""" global_step = variables.get_global_step() @@ -126,45 +159,55 @@ class _DNNLinearCombinedBaseEstimator(estimator.BaseEstimator): with ops.get_default_graph().colocate_with(global_step): return state_ops.assign_add(global_step, 1).op, loss - def _get_eval_ops(self, features, targets, metrics): - """See base class.""" - predictions = self._get_predict_ops(features) + def _run_metrics(self, predictions, targets, metrics, weights): result = {} + targets = math_ops.cast(targets, predictions.dtype) for name, metric in six.iteritems(metrics): - result[name] = metric(predictions, targets, - self._get_weight_tensor(features)) + if "weights" in inspect.getargspec(metric)[0]: + result[name] = metric(predictions, targets, weights=weights) + else: + result[name] = metric(predictions, targets) + return result - def _get_default_metric_functions(self): + def _get_eval_ops(self, features, targets, metrics=None): """See base class.""" - def _compute_loss(logits, targets, weights=None): - return metrics_lib.streaming_mean(self._loss( - logits, targets, weight_tensor=weights)) + logits = self._logits(features) + result = {"loss": metrics_lib.streaming_mean(self._loss( + logits, targets, + weight_tensor=self._get_weight_tensor(features)))} - def _compute_accuracy(logits, 
targets, weights=None): - if self._n_classes > 2: - _, predictions = nn.top_k(logits, 1) - else: - predictions = array_ops.reshape(logits, [-1]) - predictions = math_ops.greater(predictions, - array_ops.zeros_like(predictions)) - targets = array_ops.reshape(targets, [-1]) - return metrics_lib.streaming_accuracy( - math_ops.to_int32(predictions), math_ops.to_int32(targets), weights) + # Adding default metrics + if metrics is None and self._n_classes > 1: + metrics = {"accuracy": metrics_lib.streaming_accuracy} - def _compute_auc(logits, targets, unused_weights=None): - return metrics_lib.streaming_auc(math_ops.sigmoid(logits), targets) + if self._n_classes == 2: + predictions = math_ops.sigmoid(logits) + result["eval_auc"] = metrics_lib.streaming_auc(predictions, targets) + + if metrics: + predictions = self._logits_to_predictions(logits, proba=False) + result.update(self._run_metrics(predictions, targets, metrics, + self._get_weight_tensor(features))) - result = {"loss": _compute_loss} - if self._n_classes > 1: - result["accuracy"] = _compute_accuracy - # Adds AUC for binary classification problems. 
- if self._num_label_columns() == 1: - result["eval_auc"] = _compute_auc return result def _get_predict_ops(self, features): - return self._logits(features) + """See base class.""" + logits = self._logits(features) + return self._logits_to_predictions(logits, proba=True) + + def _logits_to_predictions(self, logits, proba=False): + if self._n_classes < 2: + return array_ops.reshape(logits, [-1]) + + if self._n_classes == 2: + logits = array_ops.concat(1, [array_ops.zeros_like(logits), logits]) + + if proba: + return nn.softmax(logits) + else: + return math_ops.argmax(logits, 1) def _get_feature_ops_from_example(self, examples_batch): column_types = layers.create_dict_for_parse_example( diff --git a/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined_test.py b/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined_test.py index 16d86bf73d7..d4188a44ac9 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined_test.py @@ -19,6 +19,7 @@ from __future__ import print_function import numpy as np import tensorflow as tf +from tensorflow.contrib.learn.python.learn.estimators import _sklearn def _get_quantile_based_buckets(feature_values, num_buckets): @@ -229,6 +230,58 @@ class DNNLinearCombinedClassifierTest(tf.test.TestCase): scores = classifier.evaluate(input_fn=_iris_input_fn, steps=100) self.assertGreater(scores['accuracy'], 0.9) + def testPredict(self): + """Tests weight column in evaluation.""" + + def _input_fn_train(): + # Create 4 rows, one of them (y = x), three of them (y=Not(x)) + target = tf.constant([[1], [0], [0], [0]]) + features = {'x': tf.ones(shape=[4, 1], dtype=tf.float32),} + return features, target + + def _input_fn_predict(): + features = {'x': tf.ones(shape=[4, 1], dtype=tf.float32),} + return features + + classifier = tf.contrib.learn.DNNLinearCombinedClassifier( + 
linear_feature_columns=[tf.contrib.layers.real_valued_column('x')], + dnn_feature_columns=[tf.contrib.layers.real_valued_column('x')], + dnn_hidden_units=[3, 3]) + + classifier.train(input_fn=_input_fn_train, steps=100) + probs = classifier.predict_proba(input_fn=_input_fn_predict) + self.assertAllClose([[0.75, 0.25]] * 4, probs, 0.01) + classes = classifier.predict(input_fn=_input_fn_predict) + self.assertListEqual([0] * 4, list(classes)) + + def testCustomMetrics(self): + """Tests weight column in evaluation.""" + + def _input_fn_train(): + # Create 4 rows, one of them (y = x), three of them (y=Not(x)) + target = tf.constant([[1], [0], [0], [0]]) + features = {'x': tf.ones(shape=[4, 1], dtype=tf.float32),} + return features, target + + classifier = tf.contrib.learn.DNNLinearCombinedClassifier( + linear_feature_columns=[tf.contrib.layers.real_valued_column('x')], + dnn_feature_columns=[tf.contrib.layers.real_valued_column('x')], + dnn_hidden_units=[3, 3]) + + classifier.train(input_fn=_input_fn_train, steps=100) + scores = classifier.evaluate( + input_fn=_input_fn_train, + steps=100, + metrics={ + 'my_accuracy': tf.contrib.metrics.streaming_accuracy, + 'my_precision': tf.contrib.metrics.streaming_precision + }) + self.assertTrue(set(['loss', 'my_accuracy', 'my_precision']).issubset(set( + scores.keys()))) + predictions = classifier.predict(input_fn=_input_fn_train) + self.assertEqual(_sklearn.accuracy_score([1, 0, 0, 0], predictions), + scores['my_accuracy']) + class DNNLinearCombinedRegressorTest(tf.test.TestCase): diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator.py b/tensorflow/contrib/learn/python/learn/estimators/estimator.py index d3e122f17eb..0fce7d140f1 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/estimator.py +++ b/tensorflow/contrib/learn/python/learn/estimators/estimator.py @@ -259,8 +259,7 @@ class BaseEstimator(sklearn.BaseEstimator): Returns: Numpy array of predicted classes or regression values. 
""" - return self._infer_model(x=x, input_fn=input_fn, - batch_size=batch_size) + return self._infer_model(x=x, input_fn=input_fn, batch_size=batch_size) @property def model_dir(self): @@ -296,6 +295,8 @@ class BaseEstimator(sklearn.BaseEstimator): def _get_eval_ops(self, features, targets, metrics): """Method that builds model graph and returns evaluation ops. + Expected to be overridden by sub-classes that require custom support. + Args: features: `Tensor` or `dict` of `Tensor` objects. targets: `Tensor` or `dict` of `Tensor` objects. @@ -304,11 +305,7 @@ class BaseEstimator(sklearn.BaseEstimator): Returns: metrics: `dict` of `Tensor` objects. """ - predictions = self._get_predict_ops(features) - result = {} - for name, metric in six.iteritems(metrics): - result[name] = metric(predictions, targets) - return result + raise NotImplementedError('_get_eval_ops not implemented in BaseEstimator') def _get_feature_ops_from_example(self, examples_batch): """Method that returns features given the batch of examples. @@ -324,16 +321,6 @@ class BaseEstimator(sklearn.BaseEstimator): raise NotImplementedError('_get_feature_ops_from_example not implemented ' 'in BaseEstimator') - def _get_default_metric_functions(self): - """Method that provides default metric operations. - - This functions is intented to be overridden by sub-classes. - Returns: - `dict` of functions that take predictions and targets `Tensor` objects and - return `Tensor`. 
- """ - return {} - def _check_inputs(self, features, targets): if self._features_info is not None: if not tensor_signature.tensors_compatible(features, self._features_info): @@ -460,9 +447,7 @@ class BaseEstimator(sklearn.BaseEstimator): global_step = contrib_framework.create_global_step(g) features, targets = input_fn() self._check_inputs(features, targets) - eval_dict = self._get_eval_ops(features, targets, - metrics if metrics is not None else - self._get_default_metric_functions()) + eval_dict = self._get_eval_ops(features, targets, metrics) update_op, eval_dict = self._extract_metric_update_ops(eval_dict) eval_results, _ = evaluate(graph=g, output_dir=eval_dir, @@ -475,9 +460,13 @@ class BaseEstimator(sklearn.BaseEstimator): max_steps=steps) return eval_results - def _infer_model(self, - x=None, input_fn=None, feed_fn=None, - batch_size=None): + def _get_features_from_input_fn(self, input_fn): + result = input_fn() + if isinstance(result, (list, tuple)): + return result[0] + return result + + def _infer_model(self, x=None, input_fn=None, feed_fn=None, batch_size=None): # Converts inputs into tf.DataFrame / tf.Series. batch_size = -1 if batch_size is None else batch_size if x is not None: @@ -487,7 +476,7 @@ class BaseEstimator(sklearn.BaseEstimator): with ops.Graph().as_default() as g: random_seed.set_random_seed(self._config.tf_random_seed) contrib_framework.create_global_step(g) - features, _ = input_fn() + features = self._get_features_from_input_fn(input_fn) predictions = self._get_predict_ops(features) return_dict = True if not isinstance(predictions, dict): @@ -570,7 +559,8 @@ class Estimator(BaseEstimator): Returns: Numpy array of predicted classes or regression values. 
""" - predictions = self._infer_model(x=x, input_fn=input_fn, + predictions = self._infer_model(x=x, + input_fn=input_fn, batch_size=batch_size) if self._classification: for key in predictions: @@ -589,8 +579,7 @@ class Estimator(BaseEstimator): Returns: Numpy array of predicted probabilities. """ - return self._infer_model(x=x, input_fn=input_fn, - batch_size=batch_size) + return self._infer_model(x=x, input_fn=input_fn, batch_size=batch_size) def _get_train_ops(self, features, targets): """Method that builds model graph and returns trainer ops. @@ -645,6 +634,9 @@ class Estimator(BaseEstimator): """ predictions, loss = self._model_fn(features, targets, ModeKeys.EVAL) result = {'loss': loss} + if metrics is None: + metrics = _EVAL_METRICS[ + 'classification' if self._classification else 'regression'] if isinstance(targets, dict) and len(targets) == 1: # Unpack single target into just tensor. targets = targets[targets.keys()[0]] @@ -671,15 +663,6 @@ class Estimator(BaseEstimator): predictions, _ = self._model_fn(features, targets, ModeKeys.INFER) return predictions - def _get_default_metric_functions(self): - """Method that provides default metric operations. - - Returns: - a dictionary of metric operations. - """ - return _EVAL_METRICS[ - 'classification' if self._classification else 'regression'] - def _get_feature_ops_from_example(self, examples_batch): """Unimplemented.